{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/full_text_search_with_milvus.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>   <a href=\"https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/full_text_search_with_milvus.ipynb\" target=\"_blank\">\n",
    "    <img src=\"https://img.shields.io/badge/View%20on%20GitHub-555555?style=flat&logo=github&logoColor=white\" alt=\"GitHub Repository\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Full Text Search with Milvus\n",
    "\n",
    "Introduced in Milvus 2.5, Full Text Search enables users to efficiently search for text based on keywords or phrases, providing powerful text retrieval capabilities. This feature enhances search accuracy and can be seamlessly combined with embedding-based retrieval for hybrid search, allowing for both semantic and keyword-based results in a single query. In this notebook, we will show the basic usage of full text search in Milvus."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "### Download the dataset\n",
    "The following commands download the example data used in the original Anthropic [demo](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/guide.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "powershell"
    }
   },
   "outputs": [],
   "source": [
    "# -nc (no-clobber) skips files that already exist, so re-running this cell\n",
    "# does not leave duplicate copies such as codebase_chunks.json.1 behind.\n",
    "!wget -nc https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json\n",
    "!wget -nc https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Milvus 2.5\n",
    "Check the [official installation guide](https://milvus.io/docs/install_standalone-docker-compose.md) for more details.\n",
    "\n",
    "### Install PyMilvus\n",
    "Run the following command to install PyMilvus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "powershell"
    }
   },
   "outputs": [],
   "source": [
    "# %pip (rather than a bare `pip`) installs into the environment of the running kernel.\n",
    "%pip install -U \"pymilvus[model]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from pymilvus import (\n",
    "    MilvusClient,\n",
    "    DataType,\n",
    "    Function,\n",
    "    FunctionType,\n",
    "    AnnSearchRequest,\n",
    "    RRFRanker,\n",
    ")\n",
    "\n",
    "from pymilvus.model.hybrid import BGEM3EmbeddingFunction\n",
    "\n",
    "\n",
    "class HybridRetriever:\n",
    "    \"\"\"Retriever over a single Milvus collection with sparse, dense, and hybrid search.\n",
    "\n",
    "    Each row keeps the chunk text in ``content``. A BM25 function registered on\n",
    "    the schema maps ``content`` to ``sparse_vector``, and ``dense_vector`` is\n",
    "    produced by the embedding function supplied to the constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, uri, collection_name=\"hybrid\", dense_embedding_function=None):\n",
    "        \"\"\"Create the Milvus client and store the retriever settings.\n",
    "\n",
    "        Args:\n",
    "            uri: URI handed to ``MilvusClient``.\n",
    "            collection_name: Collection that is created, filled, and searched.\n",
    "            dense_embedding_function: Callable taking a list of texts. Its\n",
    "                result is read either as ``result[\"dense\"][0]`` (dict) or\n",
    "                ``result[0]``, and its ``dim`` attribute sizes the dense field.\n",
    "        \"\"\"\n",
    "        self.uri = uri\n",
    "        self.collection_name = collection_name\n",
    "        self.embedding_function = dense_embedding_function\n",
    "        # These two flags are set here but not read by any method of this class.\n",
    "        self.use_reranker = True\n",
    "        self.use_sparse = True\n",
    "        self.client = MilvusClient(uri=uri)\n",
    "\n",
    "    def _dense_vector(self, text):\n",
    "        \"\"\"Embed one text and return its dense vector.\"\"\"\n",
    "        embedding = self.embedding_function([text])\n",
    "        # The embedding function may return a dict with a \"dense\" entry or the\n",
    "        # vectors themselves; either way the first item belongs to ``text``.\n",
    "        if isinstance(embedding, dict) and \"dense\" in embedding:\n",
    "            return embedding[\"dense\"][0]\n",
    "        return embedding[0]\n",
    "\n",
    "    def build_collection(self):\n",
    "        \"\"\"Create the collection with its schema, BM25 function, and indexes.\"\"\"\n",
    "        # ``dim`` is either a plain value or a dict keyed by embedding kind.\n",
    "        if isinstance(self.embedding_function.dim, dict):\n",
    "            dense_dim = self.embedding_function.dim[\"dense\"]\n",
    "        else:\n",
    "            dense_dim = self.embedding_function.dim\n",
    "\n",
    "        stop_words = [\n",
    "            \"a\", \"an\", \"and\", \"are\", \"as\", \"at\", \"be\", \"but\", \"by\", \"for\",\n",
    "            \"if\", \"in\", \"into\", \"is\", \"it\", \"no\", \"not\", \"of\", \"on\", \"or\",\n",
    "            \"such\", \"that\", \"the\", \"their\", \"then\", \"there\", \"these\", \"they\",\n",
    "            \"this\", \"to\", \"was\", \"will\", \"with\",\n",
    "        ]\n",
    "        # Analyzer applied to ``content``: standard tokenizer followed by the\n",
    "        # lowercase, length (max 200), English stemmer, and stop-word filters.\n",
    "        tokenizer_params = {\n",
    "            \"tokenizer\": \"standard\",\n",
    "            \"filter\": [\n",
    "                \"lowercase\",\n",
    "                {\"type\": \"length\", \"max\": 200},\n",
    "                {\"type\": \"stemmer\", \"language\": \"english\"},\n",
    "                {\"type\": \"stop\", \"stop_words\": stop_words},\n",
    "            ],\n",
    "        }\n",
    "\n",
    "        schema = MilvusClient.create_schema()\n",
    "        schema.add_field(\n",
    "            field_name=\"pk\",\n",
    "            datatype=DataType.VARCHAR,\n",
    "            is_primary=True,\n",
    "            auto_id=True,\n",
    "            max_length=100,\n",
    "        )\n",
    "        schema.add_field(\n",
    "            field_name=\"content\",\n",
    "            datatype=DataType.VARCHAR,\n",
    "            max_length=65535,\n",
    "            analyzer_params=tokenizer_params,\n",
    "            enable_match=True,\n",
    "            enable_analyzer=True,\n",
    "        )\n",
    "        schema.add_field(\n",
    "            field_name=\"sparse_vector\", datatype=DataType.SPARSE_FLOAT_VECTOR\n",
    "        )\n",
    "        schema.add_field(\n",
    "            field_name=\"dense_vector\", datatype=DataType.FLOAT_VECTOR, dim=dense_dim\n",
    "        )\n",
    "        schema.add_field(\n",
    "            field_name=\"original_uuid\", datatype=DataType.VARCHAR, max_length=128\n",
    "        )\n",
    "        schema.add_field(field_name=\"doc_id\", datatype=DataType.VARCHAR, max_length=64)\n",
    "        schema.add_field(field_name=\"chunk_id\", datatype=DataType.VARCHAR, max_length=64)\n",
    "        schema.add_field(field_name=\"original_index\", datatype=DataType.INT32)\n",
    "\n",
    "        # The BM25 function takes ``content`` as input and writes ``sparse_vector``;\n",
    "        # ``insert_data`` therefore never supplies the sparse field itself.\n",
    "        bm25_function = Function(\n",
    "            name=\"bm25\",\n",
    "            function_type=FunctionType.BM25,\n",
    "            input_field_names=[\"content\"],\n",
    "            output_field_names=\"sparse_vector\",\n",
    "        )\n",
    "        schema.add_function(bm25_function)\n",
    "\n",
    "        index_params = MilvusClient.prepare_index_params()\n",
    "        index_params.add_index(\n",
    "            field_name=\"sparse_vector\",\n",
    "            index_type=\"SPARSE_INVERTED_INDEX\",\n",
    "            metric_type=\"BM25\",\n",
    "        )\n",
    "        index_params.add_index(\n",
    "            field_name=\"dense_vector\", index_type=\"FLAT\", metric_type=\"IP\"\n",
    "        )\n",
    "\n",
    "        self.client.create_collection(\n",
    "            collection_name=self.collection_name,\n",
    "            schema=schema,\n",
    "            index_params=index_params,\n",
    "        )\n",
    "\n",
    "    def insert_data(self, chunk, metadata):\n",
    "        \"\"\"Embed ``chunk`` and insert one row into the collection.\n",
    "\n",
    "        Args:\n",
    "            chunk: Text used to compute the dense vector.\n",
    "            metadata: Dict of the remaining field values for the row; it is\n",
    "                merged with the computed ``dense_vector``.\n",
    "        \"\"\"\n",
    "        self.client.insert(\n",
    "            self.collection_name,\n",
    "            {\"dense_vector\": self._dense_vector(chunk), **metadata},\n",
    "        )\n",
    "\n",
    "    def search(self, query: str, k: int = 20, mode=\"hybrid\"):\n",
    "        \"\"\"Search the collection and return the matching chunks.\n",
    "\n",
    "        Args:\n",
    "            query: Query text.\n",
    "            k: Maximum number of results requested from Milvus.\n",
    "            mode: \"sparse\" (BM25 on the query text), \"dense\" (embedding\n",
    "                search), or \"hybrid\" (both requests fused with ``RRFRanker``).\n",
    "\n",
    "        Returns:\n",
    "            A list of dicts with ``doc_id``, ``chunk_id``, ``content``, and\n",
    "            ``score`` (the ``distance`` reported for the hit).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``mode`` is not one of the three supported values.\n",
    "        \"\"\"\n",
    "        output_fields = [\n",
    "            \"content\",\n",
    "            \"original_uuid\",\n",
    "            \"doc_id\",\n",
    "            \"chunk_id\",\n",
    "            \"original_index\",\n",
    "        ]\n",
    "\n",
    "        if mode == \"sparse\":\n",
    "            # Full text search takes the raw query string as search data.\n",
    "            results = self.client.search(\n",
    "                collection_name=self.collection_name,\n",
    "                data=[query],\n",
    "                anns_field=\"sparse_vector\",\n",
    "                limit=k,\n",
    "                output_fields=output_fields,\n",
    "            )\n",
    "        elif mode == \"dense\":\n",
    "            results = self.client.search(\n",
    "                collection_name=self.collection_name,\n",
    "                data=[self._dense_vector(query)],\n",
    "                anns_field=\"dense_vector\",\n",
    "                limit=k,\n",
    "                output_fields=output_fields,\n",
    "            )\n",
    "        elif mode == \"hybrid\":\n",
    "            dense_vec = self._dense_vector(query)\n",
    "\n",
    "            full_text_search_req = AnnSearchRequest(\n",
    "                [query], \"sparse_vector\", {\"metric_type\": \"BM25\"}, limit=k\n",
    "            )\n",
    "            dense_req = AnnSearchRequest(\n",
    "                [dense_vec], \"dense_vector\", {\"metric_type\": \"IP\"}, limit=k\n",
    "            )\n",
    "\n",
    "            results = self.client.hybrid_search(\n",
    "                self.collection_name,\n",
    "                [full_text_search_req, dense_req],\n",
    "                ranker=RRFRanker(),\n",
    "                limit=k,\n",
    "                output_fields=output_fields,\n",
    "            )\n",
    "        else:\n",
    "            raise ValueError(\"Invalid mode\")\n",
    "\n",
    "        # Only one query is sent per call, so its hits are in results[0].\n",
    "        return [\n",
    "            {\n",
    "                \"doc_id\": hit[\"entity\"][\"doc_id\"],\n",
    "                \"chunk_id\": hit[\"entity\"][\"chunk_id\"],\n",
    "                \"content\": hit[\"entity\"][\"content\"],\n",
    "                \"score\": hit[\"distance\"],\n",
    "            }\n",
    "            for hit in results[0]\n",
    "        ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 108848.72it/s]\n"
     ]
    }
   ],
   "source": [
    "# Dense vectors come from BGE-M3; the retriever talks to a local Milvus instance.\n",
    "MILVUS_URI = \"http://localhost:19530\"\n",
    "COLLECTION_NAME = \"milvus_hybrid\"\n",
    "\n",
    "dense_ef = BGEM3EmbeddingFunction()\n",
    "standard_retriever = HybridRetriever(\n",
    "    uri=MILVUS_URI,\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    dense_embedding_function=dense_ef,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Insert the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"codebase_chunks.json\"\n",
    "# Read the JSON file as UTF-8 explicitly instead of relying on the platform's\n",
    "# default text encoding.\n",
    "with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "    dataset = json.load(f)\n",
    "\n",
    "# Set to False to skip creating the collection and inserting the chunks.\n",
    "is_insert = True\n",
    "if is_insert:\n",
    "    standard_retriever.build_collection()\n",
    "    for doc in dataset:\n",
    "        for chunk in doc[\"chunks\"]:\n",
    "            chunk_content = chunk[\"content\"]\n",
    "            metadata = {\n",
    "                \"doc_id\": doc[\"doc_id\"],\n",
    "                \"original_uuid\": doc[\"original_uuid\"],\n",
    "                \"chunk_id\": chunk[\"chunk_id\"],\n",
    "                \"original_index\": chunk[\"original_index\"],\n",
    "                \"content\": chunk_content,\n",
    "            }\n",
    "            # The same text is stored in `content` and embedded for `dense_vector`.\n",
    "            standard_retriever.insert_data(chunk_content, metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Sparse Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'doc_id': 'doc_10', 'chunk_id': 'doc_10_chunk_0', 'content': 'use {\\n    crate::args::LogArgs,\\n    anyhow::{anyhow, Result},\\n    simplelog::{Config, LevelFilter, WriteLogger},\\n    std::fs::File,\\n};\\n\\npub struct Logger;\\n\\nimpl Logger {\\n    pub fn init(args: &impl LogArgs) -> Result<()> {\\n        let filter: LevelFilter = args.log_level().into();\\n        if filter != LevelFilter::Off {\\n            let logfile = File::create(args.log_file())\\n                .map_err(|e| anyhow!(\"Failed to open log file: {e:}\"))?;\\n            WriteLogger::init(filter, Config::default(), logfile)\\n                .map_err(|e| anyhow!(\"Failed to initalize logger: {e:}\"))?;\\n        }\\n        Ok(())\\n    }\\n}\\n', 'score': 9.12518310546875}, {'doc_id': 'doc_87', 'chunk_id': 'doc_87_chunk_3', 'content': '\\t\\tLoggerPtr INF = Logger::getLogger(LOG4CXX_TEST_STR(\"INF\"));\\n\\t\\tINF->setLevel(Level::getInfo());\\n\\n\\t\\tLoggerPtr INF_ERR = Logger::getLogger(LOG4CXX_TEST_STR(\"INF.ERR\"));\\n\\t\\tINF_ERR->setLevel(Level::getError());\\n\\n\\t\\tLoggerPtr DEB = Logger::getLogger(LOG4CXX_TEST_STR(\"DEB\"));\\n\\t\\tDEB->setLevel(Level::getDebug());\\n\\n\\t\\t// Note: categories with undefined level\\n\\t\\tLoggerPtr INF_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR(\"INF.UNDEF\"));\\n\\t\\tLoggerPtr INF_ERR_UNDEF = Logger::getLogger(LOG4CXX_TEST_STR(\"INF.ERR.UNDEF\"));\\n\\t\\tLoggerPtr UNDEF = Logger::getLogger(LOG4CXX_TEST_STR(\"UNDEF\"));\\n\\n', 'score': 7.0077056884765625}, {'doc_id': 'doc_89', 'chunk_id': 'doc_89_chunk_3', 'content': 'using namespace log4cxx;\\nusing namespace log4cxx::helpers;\\n\\nLOGUNIT_CLASS(FMTTestCase)\\n{\\n\\tLOGUNIT_TEST_SUITE(FMTTestCase);\\n\\tLOGUNIT_TEST(test1);\\n\\tLOGUNIT_TEST(test1_expanded);\\n\\tLOGUNIT_TEST(test10);\\n//\\tLOGUNIT_TEST(test_date);\\n\\tLOGUNIT_TEST_SUITE_END();\\n\\n\\tLoggerPtr root;\\n\\tLoggerPtr logger;\\n\\npublic:\\n\\tvoid setUp()\\n\\t{\\n\\t\\troot = 
Logger::getRootLogger();\\n\\t\\tMDC::clear();\\n\\t\\tlogger = Logger::getLogger(LOG4CXX_TEST_STR(\"java.org.apache.log4j.PatternLayoutTest\"));\\n\\t}\\n\\n', 'score': 6.750633716583252}]\n"
     ]
    }
   ],
   "source": [
    "# Keyword-only (BM25) query: request the three best-matching chunks.\n",
    "results = standard_retriever.search(query=\"create a logger?\", k=3, mode=\"sparse\")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "Now that we have inserted the dataset into Milvus, we can use dense, sparse, or hybrid search to retrieve the top 5 results. You can change the `mode` and evaluate each one. We report the Pass@5 metric: for each query we retrieve the top 5 results and compute the recall, i.e. the fraction of that query's golden chunks found among them; the final score is the average over all queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_jsonl(file_path: str):\n",
    "    \"\"\"Load a JSONL file and return a list of the parsed records.\n",
    "\n",
    "    The file is read as UTF-8. Blank lines (for example a trailing empty line)\n",
    "    are skipped instead of raising ``json.JSONDecodeError``.\n",
    "    \"\"\"\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        return [json.loads(line) for line in file if line.strip()]\n",
    "\n",
    "\n",
    "dataset = load_jsonl(\"evaluation_set.jsonl\")\n",
    "k = 5\n",
    "\n",
    "# mode can be \"dense\", \"sparse\" or \"hybrid\".\n",
    "mode = \"hybrid\"\n",
    "\n",
    "total_query_score = 0\n",
    "num_queries = 0\n",
    "\n",
    "for query_item in dataset:\n",
    "    query = query_item[\"query\"]\n",
    "\n",
    "    # Resolve each (document uuid, chunk index) pair to the golden chunk's text.\n",
    "    golden_contents = []\n",
    "    for doc_uuid, chunk_index in query_item[\"golden_chunk_uuids\"]:\n",
    "        golden_doc = next(\n",
    "            (doc for doc in query_item[\"golden_documents\"] if doc[\"uuid\"] == doc_uuid),\n",
    "            None,\n",
    "        )\n",
    "        if not golden_doc:\n",
    "            continue\n",
    "        golden_chunk = next(\n",
    "            (chunk for chunk in golden_doc[\"chunks\"] if chunk[\"index\"] == chunk_index),\n",
    "            None,\n",
    "        )\n",
    "        if golden_chunk:\n",
    "            golden_contents.append(golden_chunk[\"content\"].strip())\n",
    "\n",
    "    # Recall is undefined when no golden chunk could be resolved; skip the query\n",
    "    # instead of dividing by zero.\n",
    "    if not golden_contents:\n",
    "        continue\n",
    "\n",
    "    # Request exactly k results so that changing k above changes the evaluation.\n",
    "    results = standard_retriever.search(query, mode=mode, k=k)\n",
    "    retrieved_contents = {doc[\"content\"].strip() for doc in results[:k]}\n",
    "\n",
    "    # Per-query recall: share of golden chunks present among the top-k results.\n",
    "    chunks_found = sum(content in retrieved_contents for content in golden_contents)\n",
    "    query_score = chunks_found / len(golden_contents)\n",
    "\n",
    "    total_query_score += query_score\n",
    "    num_queries += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pass@5:  0.7911386328725037\n"
     ]
    }
   ],
   "source": [
    "# Label the metric with the k that was actually evaluated.\n",
    "print(f\"Pass@{k}: \", total_query_score / num_queries)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "runtime",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
